{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class Tensor (object):\n",
    "    \"\"\"Numpy array plus the bookkeeping needed for reverse-mode autograd.\n",
    "\n",
    "    `creators` holds the tensors this one was computed from and `creation_op`\n",
    "    names the operation.  `children` maps the id of every tensor computed from\n",
    "    this one to the number of gradients it still has to send back.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,data,\n",
    "                 autograd=False,\n",
    "                 creators=None,\n",
    "                 creation_op=None,\n",
    "                 id=None):\n",
    "        \"\"\"Store `data` as a numpy array and register with every creator.\n",
    "\n",
    "        data        -- anything np.array() accepts\n",
    "        autograd    -- whether gradients are tracked for this tensor\n",
    "        creators    -- list of parent tensors, or None for a leaf\n",
    "        creation_op -- name of the operation that produced this tensor\n",
    "        id          -- explicit identifier; a random integer is drawn when omitted\n",
    "        \"\"\"\n",
    "        self.data = np.array(data)\n",
    "        self.autograd = autograd\n",
    "        self.grad = None\n",
    "\n",
    "        if(id is None):\n",
    "            self.id = np.random.randint(0,1000000000)\n",
    "        else:\n",
    "            self.id = id\n",
    "\n",
    "        self.creators = creators\n",
    "        self.creation_op = creation_op\n",
    "        self.children = {}\n",
    "\n",
    "        # tell each parent that one more gradient will come back from this tensor\n",
    "        if(creators is not None):\n",
    "            for c in creators:\n",
    "                c.children[self.id] = c.children.get(self.id, 0) + 1\n",
    "\n",
    "    def all_children_grads_accounted_for(self):\n",
    "        \"\"\"Return True when no child still owes this tensor a gradient.\"\"\"\n",
    "        return all(cnt == 0 for cnt in self.children.values())\n",
    "\n",
    "    def _result(self, data, track, creators, op):\n",
    "        \"\"\"Wrap `data` in a Tensor, linking it into the graph only when `track` is true.\"\"\"\n",
    "        if(track):\n",
    "            return Tensor(data,\n",
    "                          autograd=True,\n",
    "                          creators=creators,\n",
    "                          creation_op=op)\n",
    "        return Tensor(data)\n",
    "\n",
    "    def backward(self,grad=None, grad_origin=None):\n",
    "        \"\"\"Accumulate `grad` into self.grad and propagate it to the creators.\n",
    "\n",
    "        grad        -- Tensor without autograd holding the incoming gradient;\n",
    "                       a tensor of ones shaped like the data when omitted\n",
    "        grad_origin -- the child tensor the gradient comes from, or None when\n",
    "                       backward() is called directly on this tensor\n",
    "\n",
    "        Does nothing when autograd is off for this tensor.\n",
    "        \"\"\"\n",
    "        if(not self.autograd):\n",
    "            return\n",
    "\n",
    "        if(grad is None):\n",
    "            grad = Tensor(np.ones_like(self.data))\n",
    "\n",
    "        if(grad_origin is not None):\n",
    "            # a child that has already delivered everything it owed is ignored\n",
    "            if(self.children[grad_origin.id] == 0):\n",
    "                return\n",
    "            self.children[grad_origin.id] -= 1\n",
    "\n",
    "        if(self.grad is None):\n",
    "            self.grad = grad\n",
    "        else:\n",
    "            self.grad += grad\n",
    "\n",
    "        # grads must not have grads of their own\n",
    "        assert grad.autograd == False\n",
    "\n",
    "        # only continue backpropping if there is something to backprop into\n",
    "        # and every child has reported; a direct call (grad_origin is None)\n",
    "        # overrides the wait\n",
    "        if(self.creators is None):\n",
    "            return\n",
    "        if(not (self.all_children_grads_accounted_for() or grad_origin is None)):\n",
    "            return\n",
    "\n",
    "        # NOTE(review): only add/sub/mul pass grad_origin on; every other op\n",
    "        # calls backward() without it, so the receiving tensor propagates\n",
    "        # immediately and its child counter is not decremented.\n",
    "        op = self.creation_op\n",
    "\n",
    "        if(op == \"add\"):\n",
    "            self.creators[0].backward(self.grad, self)\n",
    "            self.creators[1].backward(self.grad, self)\n",
    "\n",
    "        elif(op == \"sub\"):\n",
    "            self.creators[0].backward(Tensor(self.grad.data), self)\n",
    "            self.creators[1].backward(Tensor(self.grad.__neg__().data), self)\n",
    "\n",
    "        elif(op == \"mul\"):\n",
    "            new = self.grad * self.creators[1]\n",
    "            self.creators[0].backward(new , self)\n",
    "            new = self.grad * self.creators[0]\n",
    "            self.creators[1].backward(new, self)\n",
    "\n",
    "        elif(op == \"mm\"):\n",
    "            c0 = self.creators[0]\n",
    "            c1 = self.creators[1]\n",
    "            new = self.grad.mm(c1.transpose())\n",
    "            c0.backward(new)\n",
    "            new = self.grad.transpose().mm(c0).transpose()\n",
    "            c1.backward(new)\n",
    "\n",
    "        elif(op == \"transpose\"):\n",
    "            self.creators[0].backward(self.grad.transpose())\n",
    "\n",
    "        elif(\"sum\" in op):\n",
    "            # the summed axis is encoded in the op name, e.g. sum_1\n",
    "            dim = int(op.split(\"_\")[1])\n",
    "            self.creators[0].backward(self.grad.expand(dim,\n",
    "                                                       self.creators[0].data.shape[dim]))\n",
    "\n",
    "        elif(\"expand\" in op):\n",
    "            dim = int(op.split(\"_\")[1])\n",
    "            self.creators[0].backward(self.grad.sum(dim))\n",
    "\n",
    "        elif(op == \"neg\"):\n",
    "            self.creators[0].backward(self.grad.__neg__())\n",
    "\n",
    "        elif(op == \"sigmoid\"):\n",
    "            ones = Tensor(np.ones_like(self.grad.data))\n",
    "            self.creators[0].backward(self.grad * (self * (ones - self)))\n",
    "\n",
    "        elif(op == \"tanh\"):\n",
    "            ones = Tensor(np.ones_like(self.grad.data))\n",
    "            self.creators[0].backward(self.grad * (ones - (self * self)))\n",
    "\n",
    "        elif(op == \"index_select\"):\n",
    "            # scatter-add each selected row's gradient back onto its source row\n",
    "            new_grad = np.zeros_like(self.creators[0].data)\n",
    "            indices_ = self.index_select_indices.data.flatten()\n",
    "            grad_ = grad.data.reshape(len(indices_), -1)\n",
    "            for i in range(len(indices_)):\n",
    "                new_grad[indices_[i]] += grad_[i]\n",
    "            self.creators[0].backward(Tensor(new_grad))\n",
    "\n",
    "        elif(op == \"cross_entropy\"):\n",
    "            dx = self.softmax_output - self.target_dist\n",
    "            self.creators[0].backward(Tensor(dx))\n",
    "\n",
    "    def __add__(self, other):\n",
    "        \"\"\"Element-wise sum; tracked only when both operands track gradients.\"\"\"\n",
    "        return self._result(self.data + other.data,\n",
    "                            self.autograd and other.autograd,\n",
    "                            [self,other], \"add\")\n",
    "\n",
    "    def __neg__(self):\n",
    "        \"\"\"Element-wise negation.\"\"\"\n",
    "        return self._result(self.data * -1, self.autograd, [self], \"neg\")\n",
    "\n",
    "    def __sub__(self, other):\n",
    "        \"\"\"Element-wise difference; tracked only when both operands track gradients.\"\"\"\n",
    "        return self._result(self.data - other.data,\n",
    "                            self.autograd and other.autograd,\n",
    "                            [self,other], \"sub\")\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        \"\"\"Element-wise product; tracked only when both operands track gradients.\"\"\"\n",
    "        return self._result(self.data * other.data,\n",
    "                            self.autograd and other.autograd,\n",
    "                            [self,other], \"mul\")\n",
    "\n",
    "    def sum(self, dim):\n",
    "        \"\"\"Sum over axis `dim`.\"\"\"\n",
    "        return self._result(self.data.sum(dim), self.autograd, [self], \"sum_\"+str(dim))\n",
    "\n",
    "    def expand(self, dim,copies):\n",
    "        \"\"\"Insert a new axis at position `dim` holding `copies` repetitions of the data.\"\"\"\n",
    "        trans_cmd = list(range(0,len(self.data.shape)))\n",
    "        trans_cmd.insert(dim,len(self.data.shape))\n",
    "        new_data = self.data.repeat(copies).reshape(list(self.data.shape) + [copies]).transpose(trans_cmd)\n",
    "        return self._result(new_data, self.autograd, [self], \"expand_\"+str(dim))\n",
    "\n",
    "    def transpose(self):\n",
    "        \"\"\"Return the transposed tensor.\"\"\"\n",
    "        return self._result(self.data.transpose(), self.autograd, [self], \"transpose\")\n",
    "\n",
    "    def mm(self, x):\n",
    "        \"\"\"Matrix product self . x; tracked whenever this tensor tracks gradients.\"\"\"\n",
    "        return self._result(self.data.dot(x.data), self.autograd, [self,x], \"mm\")\n",
    "\n",
    "    def sigmoid(self):\n",
    "        \"\"\"Element-wise logistic sigmoid.\"\"\"\n",
    "        return self._result(1 / (1 + np.exp(-self.data)), self.autograd, [self], \"sigmoid\")\n",
    "\n",
    "    def tanh(self):\n",
    "        \"\"\"Element-wise hyperbolic tangent.\"\"\"\n",
    "        return self._result(np.tanh(self.data), self.autograd, [self], \"tanh\")\n",
    "\n",
    "    def index_select(self, indices):\n",
    "        \"\"\"Gather rows of this tensor using the integer tensor `indices`.\"\"\"\n",
    "        new = self._result(self.data[indices.data], self.autograd, [self], \"index_select\")\n",
    "        if(self.autograd):\n",
    "            # backward() needs the indices to route gradients to the right rows\n",
    "            new.index_select_indices = indices\n",
    "        return new\n",
    "\n",
    "    def _log_softmax_data(self):\n",
    "        \"\"\"Log-softmax over the last axis as a numpy array.\n",
    "\n",
    "        The row maximum is subtracted before exponentiating so that large\n",
    "        logits cannot overflow to inf/nan.\n",
    "        \"\"\"\n",
    "        shifted = self.data - self.data.max(axis=-1, keepdims=True)\n",
    "        return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))\n",
    "\n",
    "    def softmax(self):\n",
    "        \"\"\"Return the softmax over the last axis as a numpy array (not a Tensor).\"\"\"\n",
    "        return np.exp(self._log_softmax_data())\n",
    "\n",
    "    def cross_entropy(self, target_indices):\n",
    "        \"\"\"Mean softmax cross-entropy against the integer class ids in `target_indices`.\"\"\"\n",
    "        log_p = self._log_softmax_data()\n",
    "        softmax_output = np.exp(log_p)\n",
    "\n",
    "        t = target_indices.data.flatten()\n",
    "        log_p = log_p.reshape(len(t),-1)\n",
    "        # one-hot rows selected from an identity matrix\n",
    "        target_dist = np.eye(log_p.shape[1])[t]\n",
    "        loss = -(log_p * (target_dist)).sum(1).mean()\n",
    "\n",
    "        out = self._result(loss, self.autograd, [self], \"cross_entropy\")\n",
    "        if(self.autograd):\n",
    "            # cached for the backward pass\n",
    "            out.softmax_output = softmax_output\n",
    "            out.target_dist = target_dist\n",
    "        return out\n",
    "\n",
    "    def __repr__(self):\n",
    "        return str(self.data.__repr__())\n",
    "\n",
    "    def __str__(self):\n",
    "        return str(self.data.__str__())\n",
    "\n",
    "class Layer(object):\n",
    "    \"\"\"Base class for layers; keeps the list of trainable tensors.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.parameters = []\n",
    "\n",
    "    def get_parameters(self):\n",
    "        \"\"\"Return this layer's parameter list (the list itself, not a copy).\"\"\"\n",
    "        return self.parameters\n",
    "\n",
    "    \n",
    "class SGD(object):\n",
    "    \"\"\"Plain stochastic gradient descent over a list of parameter tensors.\n",
    "\n",
    "    parameters -- tensors exposing `data` and `grad` (grad may still be None)\n",
    "    alpha      -- learning rate\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, parameters, alpha=0.1):\n",
    "        self.parameters = parameters\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def zero(self):\n",
    "        \"\"\"Reset every accumulated gradient to zero in place.\"\"\"\n",
    "        for p in self.parameters:\n",
    "            # a parameter that has not received a gradient yet has grad=None\n",
    "            if(p.grad is not None):\n",
    "                p.grad.data *= 0\n",
    "\n",
    "    def step(self, zero=True):\n",
    "        \"\"\"Apply one update, data -= alpha * grad, to every parameter.\n",
    "\n",
    "        zero -- when true, each gradient is reset right after it is used.\n",
    "        Parameters without a gradient are left untouched.\n",
    "        \"\"\"\n",
    "        for p in self.parameters:\n",
    "            if(p.grad is None):\n",
    "                continue\n",
    "\n",
    "            p.data -= p.grad.data * self.alpha\n",
    "\n",
    "            if(zero):\n",
    "                p.grad.data *= 0\n",
    "\n",
    "\n",
    "class Linear(Layer):\n",
    "    \"\"\"Fully connected layer: input.mm(weight), plus a broadcast bias when enabled.\"\"\"\n",
    "\n",
    "    def __init__(self, n_inputs, n_outputs, bias=True):\n",
    "        super().__init__()\n",
    "\n",
    "        self.use_bias = bias\n",
    "\n",
    "        # standard-normal weights scaled by sqrt(2 / n_inputs)\n",
    "        W = np.random.randn(n_inputs, n_outputs) * np.sqrt(2.0/(n_inputs))\n",
    "        self.weight = Tensor(W, autograd=True)\n",
    "        self.parameters.append(self.weight)\n",
    "\n",
    "        if(self.use_bias):\n",
    "            self.bias = Tensor(np.zeros(n_outputs), autograd=True)\n",
    "            self.parameters.append(self.bias)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Project `input`; the bias is expanded to one copy per input row.\"\"\"\n",
    "        projected = input.mm(self.weight)\n",
    "        if(not self.use_bias):\n",
    "            return projected\n",
    "        return projected + self.bias.expand(0,len(input.data))\n",
    "\n",
    "\n",
    "class Sequential(Layer):\n",
    "    \"\"\"Container that applies its layers one after another.\"\"\"\n",
    "\n",
    "    def __init__(self, layers=None):\n",
    "        \"\"\"layers -- optional list of layers; a new empty list is used when omitted.\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # a default list in the signature would be shared by every instance\n",
    "        self.layers = list() if layers is None else layers\n",
    "\n",
    "    def add(self, layer):\n",
    "        \"\"\"Append `layer` to the end of the pipeline.\"\"\"\n",
    "        self.layers.append(layer)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Feed `input` through every layer in order and return the final output.\"\"\"\n",
    "        for layer in self.layers:\n",
    "            input = layer.forward(input)\n",
    "        return input\n",
    "\n",
    "    def get_parameters(self):\n",
    "        \"\"\"Return a new list with the parameters of all contained layers.\"\"\"\n",
    "        params = list()\n",
    "        for l in self.layers:\n",
    "            params += l.get_parameters()\n",
    "        return params\n",
    "\n",
    "\n",
    "class Embedding(Layer):\n",
    "    \"\"\"Lookup table with one trainable row of length `dim` per vocabulary entry.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, dim):\n",
    "        super().__init__()\n",
    "\n",
    "        self.vocab_size, self.dim = vocab_size, dim\n",
    "\n",
    "        # this random initialization style is just a convention from word2vec\n",
    "        initial = (np.random.rand(vocab_size, dim) - 0.5) / dim\n",
    "        self.weight = Tensor(initial, autograd=True)\n",
    "\n",
    "        self.parameters.append(self.weight)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Return the weight rows selected by the index tensor `input`.\"\"\"\n",
    "        rows = self.weight.index_select(input)\n",
    "        return rows\n",
    "\n",
    "\n",
    "class Tanh(Layer):\n",
    "    \"\"\"Parameter-free layer applying tanh to its input.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input):\n",
    "        activated = input.tanh()\n",
    "        return activated\n",
    "\n",
    "\n",
    "class Sigmoid(Layer):\n",
    "    \"\"\"Parameter-free layer applying the sigmoid to its input.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input):\n",
    "        activated = input.sigmoid()\n",
    "        return activated\n",
    "    \n",
    "\n",
    "class CrossEntropyLoss(object):\n",
    "    \"\"\"Loss object that delegates to the input's own cross_entropy method.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        loss = input.cross_entropy(target)\n",
    "        return loss\n",
    "\n",
    "class MSELoss(object):\n",
    "    \"\"\"Squared-error loss summed over axis 0 (no averaging).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        dif = input - target\n",
    "        squared = dif * dif\n",
    "        return squared.sum(0)\n",
    "    \n",
    "class RNNCell(Layer):\n",
    "    \"\"\"Recurrent cell: new_hidden = act(w_ih(input) + w_hh(hidden)), output = w_ho(new_hidden).\"\"\"\n",
    "\n",
    "    def __init__(self, n_inputs, n_hidden, n_output, activation='sigmoid'):\n",
    "        \"\"\"activation -- 'sigmoid' or 'tanh'; any other value raises an Exception.\"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.n_inputs = n_inputs\n",
    "        self.n_hidden = n_hidden\n",
    "        self.n_output = n_output\n",
    "\n",
    "        if(activation == 'sigmoid'):\n",
    "            self.activation = Sigmoid()\n",
    "        elif(activation == 'tanh'):\n",
    "            # must be an assignment so that forward() finds self.activation\n",
    "            self.activation = Tanh()\n",
    "        else:\n",
    "            raise Exception(\"Non-linearity not found\")\n",
    "\n",
    "        self.w_ih = Linear(n_inputs, n_hidden)\n",
    "        self.w_hh = Linear(n_hidden, n_hidden)\n",
    "        self.w_ho = Linear(n_hidden, n_output)\n",
    "\n",
    "        self.parameters += self.w_ih.get_parameters()\n",
    "        self.parameters += self.w_hh.get_parameters()\n",
    "        self.parameters += self.w_ho.get_parameters()\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        \"\"\"Advance one step and return (output, new_hidden).\"\"\"\n",
    "        from_prev_hidden = self.w_hh.forward(hidden)\n",
    "        combined = self.w_ih.forward(input) + from_prev_hidden\n",
    "        new_hidden = self.activation.forward(combined)\n",
    "        output = self.w_ho.forward(new_hidden)\n",
    "        return output, new_hidden\n",
    "\n",
    "    def init_hidden(self, batch_size=1):\n",
    "        \"\"\"Return an all-zero hidden state of shape (batch_size, n_hidden).\"\"\"\n",
    "        return Tensor(np.zeros((batch_size,self.n_hidden)), autograd=True)\n",
    "    \n",
    "class LSTMCell(Layer):\n",
    "    \"\"\"LSTM cell with a bias-free linear read-out from the hidden state.\"\"\"\n",
    "\n",
    "    def __init__(self, n_inputs, n_hidden, n_output):\n",
    "        super().__init__()\n",
    "\n",
    "        self.n_inputs, self.n_hidden, self.n_output = n_inputs, n_hidden, n_output\n",
    "\n",
    "        # input projections (with bias) for the f, i, o and candidate paths\n",
    "        self.xf = Linear(n_inputs, n_hidden)\n",
    "        self.xi = Linear(n_inputs, n_hidden)\n",
    "        self.xo = Linear(n_inputs, n_hidden)\n",
    "        self.xc = Linear(n_inputs, n_hidden)\n",
    "\n",
    "        # matching hidden-state projections (no bias)\n",
    "        self.hf = Linear(n_hidden, n_hidden, bias=False)\n",
    "        self.hi = Linear(n_hidden, n_hidden, bias=False)\n",
    "        self.ho = Linear(n_hidden, n_hidden, bias=False)\n",
    "        self.hc = Linear(n_hidden, n_hidden, bias=False)\n",
    "\n",
    "        self.w_ho = Linear(n_hidden, n_output, bias=False)\n",
    "\n",
    "        # collect parameters in the same order the layers were created\n",
    "        for layer in (self.xf, self.xi, self.xo, self.xc,\n",
    "                      self.hf, self.hi, self.ho, self.hc,\n",
    "                      self.w_ho):\n",
    "            self.parameters += layer.get_parameters()\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        \"\"\"Advance one step; `hidden` is (prev_hidden, prev_cell). Returns (output, (h, c)).\"\"\"\n",
    "        prev_hidden, prev_cell = hidden[0], hidden[1]\n",
    "\n",
    "        def pre_activation(from_input, from_hidden):\n",
    "            return from_input.forward(input) + from_hidden.forward(prev_hidden)\n",
    "\n",
    "        f = pre_activation(self.xf, self.hf).sigmoid()\n",
    "        i = pre_activation(self.xi, self.hi).sigmoid()\n",
    "        o = pre_activation(self.xo, self.ho).sigmoid()\n",
    "        g = pre_activation(self.xc, self.hc).tanh()\n",
    "\n",
    "        c = (f * prev_cell) + (i * g)\n",
    "        h = o * c.tanh()\n",
    "\n",
    "        return self.w_ho.forward(h), (h, c)\n",
    "\n",
    "    def init_hidden(self, batch_size=1):\n",
    "        \"\"\"Return (hidden, cell): zero tensors whose first column is set to 1.\"\"\"\n",
    "        state_shape = (batch_size,self.n_hidden)\n",
    "        init_hidden = Tensor(np.zeros(state_shape), autograd=True)\n",
    "        init_cell = Tensor(np.zeros(state_shape), autograd=True)\n",
    "        for state in (init_hidden, init_cell):\n",
    "            state.data[:,0] += 1\n",
    "        return (init_hidden, init_cell)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 1: Plain Ole Fashioned Deep Learning (Email Spam Detection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 442,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import Counter\n",
    "import random\n",
    "import sys\n",
    "np.random.seed(12345)\n",
    "\n",
    "# dataset from http://www2.aueb.gr/users/ion/data/enron-spam/\n",
    "\n",
    "import codecs\n",
    "\n",
    "\n",
    "def read_lines(path):\n",
    "    \"\"\"Return every line of `path`, decoded as UTF-8 with undecodable bytes dropped.\"\"\"\n",
    "    with codecs.open(path, \"r\",encoding='utf-8', errors='ignore') as fdata:\n",
    "        return fdata.readlines()\n",
    "\n",
    "\n",
    "def to_word_sets(rows):\n",
    "    \"\"\"Turn each raw line into the set of its space-separated words.\n",
    "\n",
    "    The last two characters of every line are cut off before splitting\n",
    "    (presumably the line terminator - confirm against the data files).\n",
    "    \"\"\"\n",
    "    return [set(row[:-2].split(\" \")) for row in rows]\n",
    "\n",
    "\n",
    "raw = read_lines('spam.txt')\n",
    "spam = to_word_sets(raw)\n",
    "\n",
    "raw = read_lines('ham.txt')\n",
    "ham = to_word_sets(raw)\n",
    "\n",
    "vocab = set()\n",
    "for words in spam + ham:\n",
    "    vocab.update(words)\n",
    "vocab.add(\"<unk>\")\n",
    "\n",
    "# sorted so that every word gets the same index on every run; the iteration\n",
    "# order of a set of strings can differ from one interpreter session to the next\n",
    "vocab = sorted(vocab)\n",
    "w2i = {word: i for i, word in enumerate(vocab)}\n",
    "    \n",
    "def to_indices(input, l=500):\n",
    "    \"\"\"Convert word collections into fixed-length lists of vocabulary indices.\n",
    "\n",
    "    Entries with fewer than `l` words are padded with \"<unk>\" up to length `l`;\n",
    "    entries with `l` or more words are left out of the result.\n",
    "    \"\"\"\n",
    "    indices = list()\n",
    "    for line in input:\n",
    "        n_missing = l - len(line)\n",
    "        if(n_missing <= 0):\n",
    "            continue\n",
    "        padded = list(line) + [\"<unk>\"] * n_missing\n",
    "        indices.append([w2i[word] for word in padded])\n",
    "    return indices\n",
    "            \n",
    "spam_idx = to_indices(spam)\n",
    "ham_idx = to_indices(ham)\n",
    "\n",
    "# the last 1000 examples of each class are held out for testing\n",
    "train_spam_idx, test_spam_idx = spam_idx[0:-1000], spam_idx[-1000:]\n",
    "train_ham_idx, test_ham_idx = ham_idx[0:-1000], ham_idx[-1000:]\n",
    "\n",
    "train_data, train_target = list(), list()\n",
    "test_data, test_target = list(), list()\n",
    "\n",
    "# alternate spam (target 1) and ham (target 0); the shorter class wraps around\n",
    "n_train = max(len(train_spam_idx),len(train_ham_idx))\n",
    "for i in range(n_train):\n",
    "    train_data.append(train_spam_idx[i%len(train_spam_idx)])\n",
    "    train_data.append(train_ham_idx[i%len(train_ham_idx)])\n",
    "    train_target += [[1], [0]]\n",
    "\n",
    "n_test = max(len(test_spam_idx),len(test_ham_idx))\n",
    "for i in range(n_test):\n",
    "    test_data.append(test_spam_idx[i%len(test_spam_idx)])\n",
    "    test_data.append(test_ham_idx[i%len(test_ham_idx)])\n",
    "    test_target += [[1], [0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 457,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, input_data, target_data, batch_size=500, iterations=5):\n",
    "    \"\"\"Train `model` with mini-batch SGD on a summed squared-error loss and return it.\n",
    "\n",
    "    model       -- object exposing `weight`, `forward` and `get_parameters`\n",
    "    input_data  -- sequence of equal-length lists of word indices\n",
    "    target_data -- sequence of [label] entries aligned with `input_data`\n",
    "    batch_size  -- examples per update; a trailing partial batch is not used\n",
    "    iterations  -- number of passes over the data\n",
    "    \"\"\"\n",
    "    criterion = MSELoss()\n",
    "    optim = SGD(parameters=model.get_parameters(), alpha=0.01)\n",
    "\n",
    "    n_batches = int(len(input_data) / batch_size)\n",
    "    for _ in range(iterations):\n",
    "        iter_loss = 0\n",
    "        for b_i in range(n_batches):\n",
    "            start, stop = b_i*batch_size, (b_i+1)*batch_size\n",
    "\n",
    "            # padding token should stay at 0\n",
    "            model.weight.data[w2i['<unk>']] *= 0\n",
    "            input = Tensor(input_data[start:stop], autograd=True)\n",
    "            target = Tensor(target_data[start:stop], autograd=True)\n",
    "\n",
    "            pred = model.forward(input).sum(1).sigmoid()\n",
    "            loss = criterion.forward(pred,target)\n",
    "            loss.backward()\n",
    "            optim.step()\n",
    "\n",
    "            # running mean of the per-example loss, shown on a single line\n",
    "            iter_loss += loss.data[0] / batch_size\n",
    "\n",
    "            sys.stdout.write(\"\\r\\tLoss:\" + str(iter_loss / (b_i+1)))\n",
    "        print()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 458,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, test_input, test_output):\n",
    "    \"\"\"Return the fraction of examples whose thresholded prediction equals the target.\n",
    "\n",
    "    The model's '<unk>' weight row is multiplied by zero in place first.\n",
    "    \"\"\"\n",
    "    unk_row = w2i['<unk>']\n",
    "    model.weight.data[unk_row] *= 0\n",
    "\n",
    "    input = Tensor(test_input, autograd=True)\n",
    "    target = Tensor(test_output, autograd=True)\n",
    "\n",
    "    pred = model.forward(input).sum(1).sigmoid()\n",
    "    is_correct = (pred.data > 0.5) == target.data\n",
    "    return is_correct.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 459,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one weight per vocabulary word, multiplied by zero before training starts\n",
    "model = Embedding(vocab_size=len(vocab), dim=1)\n",
    "model.weight.data *= 0\n",
    "\n",
    "# NOTE: train() builds its own loss and optimizer objects\n",
    "optim = SGD(parameters=model.get_parameters(), alpha=0.01)\n",
    "criterion = MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 446,
   "metadata": {},
   "outputs": [],
   "source": [
    "# three single-pass training rounds, reporting test accuracy after each\n",
    "for i in range(3):\n",
    "    model = train(model, train_data, train_target, iterations=1)\n",
    "    accuracy = test(model, test_data, test_target)\n",
    "    print(\"% Correct on Test Set: \" + str(accuracy*100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Federated Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 464,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split the training examples between three participants as (inputs, targets) pairs\n",
    "bob = (train_data[:1000], train_target[:1000])\n",
    "alice = (train_data[1000:2000], train_target[1000:2000])\n",
    "sue = (train_data[2000:], train_target[2000:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 465,
   "metadata": {},
   "outputs": [],
   "source": [
    "# new shared model for the federated rounds, weights multiplied by zero\n",
    "model = Embedding(len(vocab), dim=1)\n",
    "model.weight.data *= 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 466,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.21908166249699718\n",
      "\n",
      "\tStep 2: send the model to Alice\n",
      "\tLoss:0.2937106899184867\n",
      "\n",
      "\tStep 3: Send the model to Sue\n",
      "\tLoss:0.033339966977175894\n",
      "\n",
      "\tAverage Everyone's New Models\n",
      "\t% Correct on Test Set: 84.05\n",
      "\n",
      "Repeat!!\n",
      "\n",
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.06625367483630413\n",
      "\n",
      "\tStep 2: send the model to Alice\n",
      "\tLoss:0.09595374225556821\n",
      "\n",
      "\tStep 3: Send the model to Sue\n",
      "\tLoss:0.020290247881140743\n",
      "\n",
      "\tAverage Everyone's New Models\n",
      "\t% Correct on Test Set: 92.25\n",
      "\n",
      "Repeat!!\n",
      "\n",
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.030819682914453833\n",
      "\n",
      "\tStep 2: send the model to Alice\n",
      "\tLoss:0.03580324891736099\n",
      "\n",
      "\tStep 3: Send the model to Sue\n",
      "\tLoss:0.015368461608470256\n",
      "\n",
      "\tAverage Everyone's New Models\n",
      "\t% Correct on Test Set: 98.8\n",
      "\n",
      "Repeat!!\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# copy is needed for deepcopy below; imported here so the cell also runs on a fresh kernel\n",
    "import copy\n",
    "\n",
    "for i in range(3):\n",
    "    print(\"Starting Training Round...\")\n",
    "\n",
    "    # every participant trains a private copy of the current shared model\n",
    "    print(\"\\tStep 1: send the model to Bob\")\n",
    "    bob_model = train(copy.deepcopy(model), bob[0], bob[1], iterations=1)\n",
    "\n",
    "    print(\"\\n\\tStep 2: send the model to Alice\")\n",
    "    alice_model = train(copy.deepcopy(model), alice[0], alice[1], iterations=1)\n",
    "\n",
    "    print(\"\\n\\tStep 3: Send the model to Sue\")\n",
    "    sue_model = train(copy.deepcopy(model), sue[0], sue[1], iterations=1)\n",
    "\n",
    "    # the shared model becomes the plain average of the three weight arrays\n",
    "    print(\"\\n\\tAverage Everyone's New Models\")\n",
    "    model.weight.data = (bob_model.weight.data + \\\n",
    "                         alice_model.weight.data + \\\n",
    "                         sue_model.weight.data)/3\n",
    "\n",
    "    accuracy = test(model, test_data, test_target)\n",
    "    print(\"\\t% Correct on Test Set: \" + str(accuracy*100))\n",
    "\n",
    "    print(\"\\nRepeat!!\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hacking Federated Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 468,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tLoss:0.0005\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "\n",
    "bobs_email = [\"my\", \"computer\", \"password\", \"is\", \"pizza\"]\n",
    "\n",
    "# a single training example: the word indices of the email, with target 0\n",
    "email_indices = [w2i[x] for x in bobs_email]\n",
    "bob_input = np.array([email_indices])\n",
    "bob_target = np.array([[0]])\n",
    "\n",
    "model = Embedding(vocab_size=len(vocab), dim=1)\n",
    "model.weight.data *= 0\n",
    "\n",
    "# train a copy so that `model` keeps the weights from before the update\n",
    "bobs_model = train(copy.deepcopy(model), bob_input, bob_target, iterations=1, batch_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 469,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is\n",
      "pizza\n",
      "computer\n",
      "password\n",
      "my\n"
     ]
    }
   ],
   "source": [
    "# print the word of every embedding row that differs between Bob's model and the shared one\n",
    "weight_delta = bobs_model.weight.data - model.weight.data\n",
    "for i, v in enumerate(weight_delta):\n",
    "    if(v != 0):\n",
    "        print(vocab[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homomorphic Encryption"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 485,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Answer: 8\n"
     ]
    }
   ],
   "source": [
    "import phe\n",
    "\n",
    "public_key, private_key = phe.generate_paillier_keypair(n_length=1024)\n",
    "\n",
    "# encrypt the numbers 5 and 3 with the public key\n",
    "x = public_key.encrypt(5)\n",
    "y = public_key.encrypt(3)\n",
    "\n",
    "# add the two encrypted values without decrypting them\n",
    "z = x + y\n",
    "\n",
    "# only the private key can turn the result back into a plain number\n",
    "z_ = private_key.decrypt(z)\n",
    "print(\"The Answer: \" + str(z_))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Secure Aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 567,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Embedding(len(vocab), dim=1)\n",
    "model.weight.data *= 0\n",
    "\n",
    "# a 128-bit key is used here; note that in production the n_length\n",
    "# should be at least 1024\n",
    "public_key, private_key = phe.generate_paillier_keypair(n_length=128)\n",
    "\n",
    "def train_and_encrypt(model, input, target, pubkey):\n",
    "    \"\"\"Train a private copy of `model` for one pass and return its encrypted weights.\n",
    "\n",
    "    model  -- model to copy; the original is not modified by the training\n",
    "    input  -- training inputs passed on to train()\n",
    "    target -- training targets passed on to train()\n",
    "    pubkey -- key object whose encrypt() is applied to every weight\n",
    "\n",
    "    Returns a numpy array shaped like the weight matrix that holds the\n",
    "    encrypted values.\n",
    "    \"\"\"\n",
    "    new_model = train(copy.deepcopy(model), input, target, iterations=1)\n",
    "\n",
    "    # only column 0 is read, so this expects a weight matrix with a single column\n",
    "    encrypted_weights = [pubkey.encrypt(val) for val in new_model.weight.data[:,0]]\n",
    "    ew = np.array(encrypted_weights).reshape(new_model.weight.data.shape)\n",
    "\n",
    "    return ew"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 568,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 569,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.21908166249699718\n",
      "\n",
      "\tStep 2: send the model to Alice\n",
      "\tLoss:0.2937106899184867\n",
      "\n",
      "\tStep 3: Send the model to Sue\n",
      "\tLoss:0.033339966977175894\n",
      "\n",
      "\tStep 4: Bob, Alice, and Sue send their\n",
      "\tencrypted models to each other.\n",
      "\n",
      "\tStep 5: only the aggregated model\n",
      "\n",
      "\tis sent back to the model owner who\n",
      "\n",
      "\t can decrypt it.\n",
      "\t% Correct on Test Set: 98.75\n",
      "\n",
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.063664140530356044\n",
      "\n",
      "\tStep 2: send the model to Alice\n",
      "\tLoss:0.06100035791351306\n",
      "\n",
      "\tStep 3: Send the model to Sue\n",
      "\tLoss:0.025483920416627266\n",
      "\n",
      "\tStep 4: Bob, Alice, and Sue send their\n",
      "\tencrypted models to each other.\n",
      "\n",
      "\tStep 5: only the aggregated model\n",
      "\n",
      "\tis sent back to the model owner who\n",
      "\n",
      "\t can decrypt it.\n",
      "\t% Correct on Test Set: 99.05000000000001\n",
      "\n",
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.058477976535441636\n",
      "\n",
      "\tStep 2: send the model to Alice\n",
      "\tLoss:0.05987976552444443\n",
      "\n",
      "\tStep 3: Send the model to Sue\n",
      "\tLoss:0.024763428511034746\n",
      "\n",
      "\tStep 4: Bob, Alice, and Sue send their\n",
      "\tencrypted models to each other.\n",
      "\n",
      "\tStep 5: only the aggregated model\n",
      "\n",
      "\tis sent back to the model owner who\n",
      "\n",
      "\t can decrypt it.\n",
      "\t% Correct on Test Set: 99.15\n",
      "\n",
      "Starting Training Round...\n",
      "\tStep 1: send the model to Bob\n",
      "\tLoss:0.0579450413900613\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-569-a3a66f9d95e5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\tStep 1: send the model to Bob\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     bob_encrypted_model = train_and_encrypt(copy.deepcopy(model), \n\u001b[0;32m----> 5\u001b[0;31m                                             bob[0], bob[1], public_key)\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n\\tStep 2: send the model to Alice\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-568-5c89678f90a5>\u001b[0m in \u001b[0;36mtrain_and_encrypt\u001b[0;34m(model, input, target, pubkey)\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mencrypted_weights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mval\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mnew_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m         \u001b[0mencrypted_weights\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpublic_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencrypt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m     \u001b[0mew\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencrypted_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/atrask/anaconda/lib/python3.6/site-packages/phe-1.3.0-py3.6.egg/phe/paillier.py\u001b[0m in \u001b[0;36mencrypt\u001b[0;34m(self, value, precision, r_value)\u001b[0m\n\u001b[1;32m    171\u001b[0m             \u001b[0mencoding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEncodedNumber\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprecision\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 173\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencrypt_encoded\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    174\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    175\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mencrypt_encoded\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/atrask/anaconda/lib/python3.6/site-packages/phe-1.3.0-py3.6.egg/phe/paillier.py\u001b[0m in \u001b[0;36mencrypt_encoded\u001b[0;34m(self, encoding, r_value)\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0mencrypted_number\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEncryptedNumber\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mciphertext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexponent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    190\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mr_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 191\u001b[0;31m             \u001b[0mencrypted_number\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobfuscate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    192\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mencrypted_number\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/atrask/anaconda/lib/python3.6/site-packages/phe-1.3.0-py3.6.egg/phe/paillier.py\u001b[0m in \u001b[0;36mobfuscate\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    860\u001b[0m             \u001b[0msend_to_nsa\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mproduct\u001b[0m\u001b[0;34m)\u001b[0m   \u001b[0;31m# NSA can't deduce 2.718 by bruteforce attack\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    861\u001b[0m         \"\"\"\n\u001b[0;32m--> 862\u001b[0;31m         \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpublic_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_random_lt_n\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    863\u001b[0m         \u001b[0mr_pow_n\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpowmod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpublic_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpublic_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnsquare\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    864\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__ciphertext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__ciphertext\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mr_pow_n\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpublic_key\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnsquare\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/atrask/anaconda/lib/python3.6/site-packages/phe-1.3.0-py3.6.egg/phe/paillier.py\u001b[0m in \u001b[0;36mget_random_lt_n\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    139\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget_random_lt_n\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    140\u001b[0m         \u001b[0;34m\"\"\"Return a cryptographically random number less than :attr:`n`\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSystemRandom\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    142\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    143\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mencrypt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprecision\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/atrask/anaconda/lib/python3.6/random.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     85\u001b[0m     \u001b[0mVERSION\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m     \u001b[0;31m# used by getstate/setstate\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     88\u001b[0m         \"\"\"Initialize an instance.\n\u001b[1;32m     89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Each round: every participant trains a private copy of the shared model and\n",
    "# returns it encrypted; only the SUM of the encrypted models is decrypted, so\n",
    "# no individual participant's weights are ever revealed to the model owner.\n",
    "for i in range(3):\n",
    "    print(\"\\nStarting Training Round...\")\n",
    "    print(\"\\tStep 1: send the model to Bob\")\n",
    "    bob_encrypted_model = train_and_encrypt(copy.deepcopy(model),\n",
    "                                            bob[0], bob[1], public_key)\n",
    "\n",
    "    print(\"\\n\\tStep 2: send the model to Alice\")\n",
    "    alice_encrypted_model = train_and_encrypt(copy.deepcopy(model),\n",
    "                                              alice[0], alice[1], public_key)\n",
    "\n",
    "    print(\"\\n\\tStep 3: Send the model to Sue\")\n",
    "    sue_encrypted_model = train_and_encrypt(copy.deepcopy(model),\n",
    "                                            sue[0], sue[1], public_key)\n",
    "\n",
    "    print(\"\\n\\tStep 4: Bob, Alice, and Sue send their\")\n",
    "    print(\"\\tencrypted models to each other.\")\n",
    "    encrypted_models = [bob_encrypted_model,\n",
    "                        alice_encrypted_model,\n",
    "                        sue_encrypted_model]\n",
    "    # element-wise addition of the encrypted weights; nothing is decrypted here\n",
    "    aggregated_model = encrypted_models[0]\n",
    "    for encrypted_model in encrypted_models[1:]:\n",
    "        aggregated_model = aggregated_model + encrypted_model\n",
    "\n",
    "    print(\"\\n\\tStep 5: only the aggregated model\")\n",
    "    print(\"\\tis sent back to the model owner who\")\n",
    "    print(\"\\t can decrypt it.\")\n",
    "    # decrypt the aggregate (not any single participant's model) ...\n",
    "    raw_values = [private_key.decrypt(val)\n",
    "                  for val in aggregated_model.flatten()]\n",
    "    # ... and divide the summed weights by the number of participants to average\n",
    "    summed_weights = np.array(raw_values).reshape(model.weight.data.shape)\n",
    "    model.weight.data = summed_weights / len(encrypted_models)\n",
    "\n",
    "    accuracy = test(model, test_data, test_target) * 100\n",
    "    print(\"\\t% Correct on Test Set: \" + str(accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def train_and_encrypt()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
